# PL/0 compiler

# PL/0 scanner

# Initialize the scanner: fill the keyword table and open the input source file.
def scannerInitialize (fname):
    """Prepare the scanner: fill the keyword table and open the source file.

    Populates the module-level ``keywords`` dict (reserved word -> token
    code) and opens ``fname`` for reading into the global ``fp``.  If the
    file cannot be opened, an error is printed and the process exits.

    Args:
        fname: path of the PL/0 source file.
    """
    global fp
    import sys  # function-scope import: the file has no top-level import block

    # Keys are upper case and nextSy looks identifiers up verbatim, so
    # reserved words are recognised only when written in upper case.
    keywords["CONST"] = constSy
    keywords["VAR"] = varSy
    keywords["PROCEDURE"] = functionSy
    keywords["BEGIN"] = beginSy
    keywords["END"] = endSy
    keywords["IF"] = ifSy
    keywords["THEN"] = thenSy
    keywords["WHILE"] = whileSy
    keywords["DO"] = doSy
    # returnSy is the token outSy renders as " RETURN ".
    keywords["RETURN"] = returnSy
    keywords["WRITE"] = writeSy
    keywords["ODD"] = oddSy
    keywords["CALL"] = callSy
    try:
        fp = open (fname, "r")
    except OSError:
        print ("No such file. ")
        sys.exit()
    print ("File "+fname+" is open.")
    
# Return the next character in the input stream
def nextCh ():
    """Advance the scanner by one character.

    Reads the next character from ``fp`` into the global ``ch`` and returns
    it.  At end of input ``ch`` becomes the empty string (the rest of the
    scanner tests for this with ``len(ch)``) and the ``eof`` sentinel is
    returned.  Once there, further calls keep returning ``eof`` without
    touching the file.

    Returns:
        The character just read, or ``eof`` at end of input or after a
        read error.
    """
    global ch

    if len(ch) == 0:
        return eof
    try:
        ch = fp.read(1)
    except (OSError, ValueError) as err:
        # ValueError covers undecodable input and reading a closed file.
        # Treat any read failure as end of input, so the scanning loops stop
        # instead of spinning on a stale character.
        print('An error occured trying to read the file:', err)
        ch = ""
    if len(ch) == 0:
        # read(1) returns "" at end of file; that is normal, not an error.
        return eof
    return ch

def nextSy ():
    """Scan and return the next token code from the input.

    Skips blanks, tabs, carriage returns and newlines.  It then recognises
    an identifier/keyword, a number, or an operator/punctuation symbol.
    Side effects: advances ``ch``.  Sets ``ident`` for identifiers/keywords
    (via scanIdent) and ``numberVal`` for numbers (via scanNumber).

    Returns:
        A token code such as ``identSy``, ``numberSy`` or ``plusSy``.
        ``finalSy`` at end of input or after an unrecognised character.
        ``errorSy`` for a ``:`` not followed by ``=``.
    """
    # Skip white space.  A tuple (not a string) is used for the membership
    # test, so the empty end-of-input ``ch`` is not mistaken for white space.
    while ch in (" ", "\t", "\r", "\n"):
        nextCh()
    # Check for end of input *after* skipping, so trailing white space does
    # not produce a spurious "Bad symbol" message.
    if len(ch) <= 0:
        return finalSy

    if letter(ch):
        scanIdent()
        return keywords.get(ident, identSy)

    if digit(ch):
        scanNumber()
        return numberSy

    # Tokens that consist of exactly one character.
    single = {
        "+": plusSy, "-": minusSy, "*": multSy, "/": divideSy,
        "=": equalSy, "#": noteqSy, "(": lparenSy, ")": rparenSy,
        ";": semiSy, ",": commaSy, ".": periodSy, "!": bangSy,
        "?": questSy,
    }
    if ch in single:
        code = single[ch]
        nextCh()
        return code

    # Tokens that may be one or two characters long.
    if ch == "<":
        nextCh()
        if ch == "=":
            nextCh()
            return lesseqSy
        if ch == ">":
            nextCh()
            return noteqSy
        return lessSy
    if ch == ">":
        nextCh()
        if ch == "=":
            nextCh()
            return greatereqSy
        return greaterSy
    if ch == ":":
        nextCh()
        if ch == "=":
            nextCh()
            return assignSy
        return errorSy

    print ("Bad symbol ", ch)
    return finalSy
    
def digit (c):
    """Return True if c is an ASCII decimal digit ("0" through "9")."""
    return "0" <= c <= "9"

def digitVal (c):
    """Return the numeric value of digit character c (its offset from "0")."""
    zero = ord("0")
    return ord(c) - zero

def letter (c):
    """Return True if c is an ASCII letter (A-Z or a-z)."""
    return "A" <= c <= "Z" or "a" <= c <= "z"

def identChar(c):
    """Return True if c may appear inside an identifier: ASCII letter, digit or "_"."""
    return c == "_" or "0" <= c <= "9" or "A" <= c <= "Z" or "a" <= c <= "z"
    
def scanNumber():
    """Consume the run of decimal digits starting at ch into numberVal.

    On return, ch holds the first character after the number.
    """
    global numberVal
    total = 0
    while digit(ch):
        total = 10 * total + digitVal(ch)
        nextCh()
    numberVal = total
        
def scanIdent ():
    """Consume the identifier characters starting at ch and store them in ident.

    On return, ch holds the first character that cannot be part of an
    identifier.
    """
    global ident
    chars = []
    while identChar(ch):
        chars.append(ch)
        nextCh()
    ident = "".join(chars)

def outSy (ss):
    """Return a printable rendering of token code ss.

    Used for diagnostics and, by the parser, to emit operator text into the
    generated C.  Numbers and identifiers include the current numberVal or
    ident.  Unknown codes yield "".
    """
    if ss == numberSy:
        return "Number "+str(numberVal)
    if ss == identSy:
        return " Identifier "+ident
    spelling = {
        plusSy: " + ",
        minusSy: " - ",
        multSy: " * ",
        divideSy: " / ",
        equalSy: " = ",
        lesseqSy: " <= ",
        noteqSy: " != ",
        lessSy: " < ",
        greatereqSy: " >= ",
        greaterSy: " > ",
        lparenSy: " ( ",
        rparenSy: " ) ",
        semiSy: " ; ",
        commaSy: " , ",
        periodSy: " . ",
        bangSy: " ! ",
        questSy: " read ",
        constSy: " CONST ",
        varSy: " VAR ",
        functionSy: " PROCEDURE ",
        beginSy: " BEGIN ",
        endSy: " END ",
        ifSy: " IF ",
        thenSy: " THEN ",
        whileSy: " WHILE ",
        doSy: " DO ",
        returnSy: " RETURN ",
        writeSy: " WRITE ",
        oddSy: " ODD ",
        callSy: " CALL ",
        finalSy: " <FINAL> ",
        noSy: " <NONE> ",
        errorSy: " <Error> ",
        assignSy: " := ",
    }
    return spelling.get(ss, "")

# PL/0 Symbol table

def defineIdent (id):
    """Enter identifier ``id`` in the symbol table at the current lexLevel.

    Each new definition receives the next value of the global counter
    ``nextVariable``; convertIdent() later turns that number into a C name.
    A second definition of the same name at the same level is reported and
    ignored.
    """
    global nextVariable
    scope = symbols[lexLevel]
    if id in scope:
        print ("Error: Duplicate definition of ", id, scope[id])
        return
    scope[id] = nextVariable
    nextVariable = nextVariable + 1
    
def convertIdent (id):
    """Return the C variable name for identifier ``id``.

    Scope levels are searched from the deepest entry of ``symbols`` down to
    level 0, so the innermost definition wins.  The generated name has the
    form ``var<number>_<level>``.  If ``id`` is undefined, an error is
    printed and the placeholder "not found" is returned; that placeholder
    is not a valid C expression.
    """
    for level in range(len(symbols) - 1, -1, -1):
        if id in symbols[level]:
            return "var"+str(symbols[level][id])+"_"+str(level)
    print ("Identifier not found: ", id)
    return "not found"
        
def getIdent (id):
    """Return the number assigned to identifier ``id`` by defineIdent.

    Searches from the deepest scope level down to level 0 and prints a trace
    line when the name is found.  Prints an error and returns 9999 when
    ``id`` is not defined at any level.
    """
    for lev in range(len(symbols) - 1, -1, -1):
        if id in symbols[lev]:
            k = symbols[lev][id]
            print ("Ident ", id, " is ", k)
            return k
    print ("Identifier not found: ", id)
    return 9999

def lexBack():
    """Discard every identifier defined at the current lexLevel.

    functionStmt calls this after a procedure body, so the procedure's local
    names go out of scope.
    """
    symbols[lexLevel] = dict()

# PL/0 parser

def program ():
    """Parse a whole PL/0 program and emit the corresponding C file.

    Emits, in order: a "// Main." marker, the C prelude (genLibrary),
    global declarations for the optional CONST and VAR sections, one C
    function per PROCEDURE, and finally ``int main()`` containing the main
    statement.  Expects the program to end with ".", and reports an error
    otherwise.
    """
    global sy, main

    gen1 ("// Main. ")
    genLibrary()
    if sy == constSy:
        constStmt()
    if sy == varSy:
        varStmt()
    while sy == functionSy:
        functionStmt()
        if sy == semiSy:
            sy = nextSy()

    gen1 ("int main()")
    gen1 ("{")
    # Record in the module-level flag that generation is now inside main().
    main = True
    statement()
    gen1 ("}")
    if sy == periodSy:
        allDone()
    else:
        print ("Error on final symbol.")

        
def block ():
    """Parse a procedure body and emit its C code.

    A body consists of optional CONST and VAR sections, any number of nested
    procedure declarations, then one statement.  A "// Block." marker line
    is emitted first.
    """
    global sy
    gen1 ("// Block. ")
    # Optional declaration sections, checked in this fixed order.
    for section, parse in ((constSy, constStmt), (varSy, varStmt)):
        if sy == section:
            parse()
    # Nested procedure declarations, each optionally followed by ";".
    while sy == functionSy:
        functionStmt()
        if sy != semiSy:
            continue
        sy = nextSy()
    statement()

def allDone ():
    """Append the closing marker comment to the generated C file."""
    outf.write("//Program complete." + "\n")
    
def constStmt ():
    """Parse a CONST section and emit one C ``int`` declaration per constant.

    Expects sy == constSy on entry.  Handles
    ``CONST ident = number {, ident = number} ;``.  Each constant is entered
    in the symbol table.  Errors are reported and parsing continues.
    """
    global sy
    sy = nextSy()
    while sy == identSy:
        defineIdent (ident)
        gen3n ("  int ", convertIdent(ident), " = ")
        sy = nextSy()
        if sy == equalSy:
            sy = nextSy()
        else:
            print ("error: = expected.")
        if sy == numberSy:
            gen2 (str(numberVal), ";")
            sy = nextSy()
        else:
            print ("Numerical constant expected.")
        if sy == commaSy:
            sy = nextSy()
        elif sy == semiSy:
            sy = nextSy()
            break
        else:
            # Parsing still continues if another identifier follows.
            print ("error: , or ; expected.", outSy(sy))
             
def varStmt ():
    """Parse a VAR section and emit a single C ``int`` declaration.

    Expects sy == varSy on entry.  Each identifier is entered in the symbol
    table and emitted under its generated C name, separated by commas.  A
    missing terminating semicolon is reported, consistent with the other
    declaration parsers.
    """
    global sy
    sy = nextSy()
    gen1n ("  int ")
    while sy == identSy:
        defineIdent (ident)
        gen1n (convertIdent(ident))
        sy = nextSy()
        if sy == commaSy:
            sy = nextSy()
            gen1n (", ")
    if sy == semiSy:
        sy = nextSy()
    else:
        print ("Missing semicolon after VAR ", outSy(sy))
    gen1 (";")

def functionStmt ():
    """Parse one PROCEDURE declaration and emit it as a C function.

    Handles ``PROCEDURE ident ; block ;``.  The procedure name is defined at
    the enclosing level.  The body is parsed one lexLevel deeper, and that
    level's symbols are discarded afterwards.
    """
    global sy, lexLevel

    sy = nextSy()                      # skip PROCEDURE
    if sy != identSy:
        print ("Missing ident")
    else:
        defineIdent (ident)
        gen3 (" void ", convertIdent(ident), "(){")
        sy = nextSy()

    if sy != semiSy:
        print ("Missing semicolon ", outSy(sy))
    else:
        sy = nextSy()

    lexLevel += 1                      # declarations in the body are local
    block()
    if sy != semiSy:
        print ("Missing semicolon")
    else:
        sy = nextSy()
    gen1 ("}")
    lexBack()
    lexLevel -= 1

def callStmt():
    """Parse ``CALL ident`` and emit a call to the procedure's C function.

    Expects sy == callSy on entry and leaves sy at the token after the name.
    If no identifier follows CALL, an error is reported and nothing is
    emitted or consumed.  This avoids emitting a call through a stale ident.
    """
    global sy
    sy = nextSy()                      # skip CALL
    if sy != identSy:
        print ("Error: CALL expects a procedure name.", outSy(sy))
        return
    gen3 ("", convertIdent(ident),  "();")
    sy = nextSy()
    
def inputStmt():
    """Parse ``? ident`` and emit ``<var> = read();``.

    Afterwards it prints a trace line showing the token that follows the
    statement.
    """
    global sy

    sy = nextSy()                      # consume '?'
    if sy != identSy:
        print("Error: READ expects a variable.")
    else:
        target = convertIdent(ident)
        gen2 (target, " = read();")
        sy = nextSy()
    print ("After READ ", outSy(sy))

        
def outputStmt ():
    """Parse ``! expression`` and emit a ``println (...)`` call for it.

    Expects sy == bangSy on entry and leaves sy at the first token after the
    expression.  Because a full expression is parsed, ``! x``, ``! 5`` and
    ``! x * 2`` are all accepted.
    """
    global sy
    sy = nextSy()                      # skip '!'
    gen1n ("println (")
    expression()
    gen1 (");")
        
def statement ():
    """Parse one PL/0 statement and emit the equivalent C.

    Handles CALL, ``?`` (read), ``!`` (write), BEGIN...END, IF...THEN,
    WHILE...DO, assignment ``ident := expression``, and the empty statement
    (a statement position followed directly by END, ";" or ".").  Errors
    are reported and parsing continues.  IF and WHILE always close the C
    brace they open, even when THEN or DO is missing.
    """
    global sy

    if sy == callSy:
        callStmt()
    elif sy == questSy:
        inputStmt()
    elif sy == bangSy:
        outputStmt()
    elif sy == beginSy:
        sy = nextSy()
        statement()
        while sy == semiSy:
            sy = nextSy()
            statement()
        if sy == endSy:
            sy = nextSy()
        else:
            print ("Error: No END for BEGIN.", outSy(sy))
    elif sy == ifSy:
        gen1n ("  if (")
        sy = nextSy()
        condition()
        gen1 (") {")
        if sy == thenSy:
            sy = nextSy()
        else:
            print ("Missing THEN.")
        statement()
        gen1 ("}")
    elif sy == whileSy:
        gen1n ("while (")
        sy = nextSy()
        condition()
        gen1 (") {")
        if sy == doSy:
            sy = nextSy()
        else:
            print ("Missing DO.")
        statement()
        gen1 ("}")
    elif sy == identSy:
        gen2n (convertIdent(ident), " = ")
        sy = nextSy()
        if sy == assignSy:
            sy = nextSy()
        else:
            print ("Missing := in assignment.", outSy(sy))
        expression()
        gen1 (";")
    elif sy in (endSy, semiSy, periodSy):
        # Empty statement, e.g. "x := 1; END" or "BEGIN END": emit nothing.
        pass
    else:
        print ("Syntax error in statement.", sy)
            
def condition():
    """Parse a PL/0 condition and emit it as a C expression.

    A condition is either ``ODD expression`` or
    ``expression relop expression``, where relop is one of
    ``= # <> < <= > >=``.  PL/0 "=" is emitted as C "==", because a bare
    C "=" would be an assignment.  ODD is emitted as ``odd(...) != 0`` so it
    also holds for negative operands, where C ``%`` yields -1.  A missing
    relational operator is reported.
    """
    global sy
    if sy == oddSy:
        sy = nextSy()
        gen1n ("odd( ")
        expression()
        gen1n (") != 0")
        return
    expression()
    relops = (equalSy, lessSy, lesseqSy, greaterSy, greatereqSy, noteqSy)
    if sy in relops:
        gen1n (" == " if sy == equalSy else outSy(sy))
        sy = nextSy()
        expression()
    else:
        print ("Relational operator expected in condition.", outSy(sy))

def outx (s):
    """Debug trace: print s and the current token, indented two spaces per level of indent."""
    print ("  " * indent + str(s), outSy(sy))
    
def expression ():
    """Parse ``[+|-] term {(+|-) term}`` and emit it as a C expression.

    A leading sign applies to the first term only.  Its parenthesis is
    therefore closed right after that term, so ``-a + b`` is emitted with
    the meaning (-a) + b rather than -(a + b).  Also prints a debug trace
    via outx, nesting one level deeper while parsing.
    """
    global sy, indent

    outx('expression')
    indent += 1
    if sy == plusSy or sy == minusSy:
        gen1n ("(" if sy == plusSy else "-( ")
        sy = nextSy()
        term()
        gen1n (")")
    else:
        term()
    while sy == plusSy or sy == minusSy:
        gen1n (outSy(sy))
        outx("e           ")
        sy = nextSy()
        term()
    indent -= 1

            
def term ():
    """Parse ``factor {(*|/) factor}`` and emit it wrapped in parentheses.

    Also prints a debug trace via outx, nesting one level deeper while
    parsing.
    """
    global sy, indent

    outx ("term")
    indent += 1
    gen1n ("(")
    factor()
    while sy in (multSy, divideSy):
        outx("t           ")
        gen1n (outSy(sy))
        sy = nextSy()
        factor()
    gen1n (")")
    indent -= 1
            
def factor ():
    """Parse ``ident | number | ( expression )`` and emit it in parentheses.

    Prints "Factor error." without consuming a token when the current token
    cannot start a factor.  Also prints a debug trace via outx, nesting one
    level deeper while parsing.
    """
    global sy, indent

    outx ("Factor ")
    indent += 1
    gen1n ("(")
    if sy == lparenSy:
        gen1n (" ( ")
        sy = nextSy()
        expression()
        if sy != rparenSy:
            print("Possible missing ')'")
        else:
            sy = nextSy()
            gen1n (" ) ")
            outx ("end (exp)")
    elif sy == identSy or sy == numberSy:
        gen1n (convertIdent(ident) if sy == identSy else str(numberVal))
        sy = nextSy()
    else:
        print ("Factor error.")
    gen1n (")")
    indent -= 1
        
# Code generation
def gen1n (s):
    """Write s to the generated C file without a trailing newline."""
    outf.write(str(s))

def gen1 (s):
    """Write s to the generated C file, followed by a newline."""
    outf.write(str(s) + "\n")


def gen2n (a, b):
    """Write a and b to the generated C file, space separated, with no newline.

    The "n" suffix means "no newline", matching gen1n and gen3n.
    """
    print (a, b, end="", file=outf)
def gen2 (a, b):
    """Write a and b to the generated C file, space separated, plus a newline."""
    outf.write(" ".join((str(a), str(b))) + "\n")

def gen3 (a, b, c):
    """Write a, b and c to the generated C file, space separated, plus a newline."""
    outf.write(" ".join(str(part) for part in (a, b, c)) + "\n")
def gen3n (a, b, c):
    """Write a, b and c to the generated C file, space separated, with no newline."""
    outf.write(" ".join(str(part) for part in (a, b, c)))

def genLibrary():
    """Emit the C prelude that the generated program relies on.

    Defines:
      * ``odd(x)``: 1 for odd x and 0 for even x, including negative x.
        A bare ``x % 2`` would give -1 for negative odd numbers in C.
      * ``read()``: prompts and reads one int from stdin.  Returns 0 if no
        integer could be read, instead of leaving the variable uninitialized.
      * ``println(i)``: prints "output :<i>" on its own line.
    """
    lines = (
        "#include <stdio.h>",
        "int odd(int x){ return x % 2 != 0; }",
        "int read ()",
        "{",
        "  int i = 0;",
        '  printf("Input: ");',
        "  fflush(stdout);",
        '  if (scanf("%d", &i) != 1) i = 0;',
        "  return i;",
        "}",
        'void println (int i) {  printf ("output :%d\\n", i); }',
    )
    for line in lines:
        print(line, file=outf)


# Scanner state.
fp = None          # input source file, opened by scannerInitialize
ch = chr(1)        # current character (non-empty placeholder before the first read)
eof = chr(0)       # sentinel that nextCh returns once the input is exhausted

# Token codes.  The ordinary tokens are numbered consecutively from 101,
# in exactly this order.
(numberSy, plusSy, minusSy, multSy, divideSy, equalSy, lesseqSy, noteqSy,
 lessSy, greatereqSy, greaterSy, lparenSy, rparenSy, semiSy, commaSy,
 periodSy, identSy, bangSy, questSy, constSy, varSy, functionSy, beginSy,
 endSy, ifSy, thenSy, whileSy, doSy, returnSy, writeSy, oddSy, callSy,
 assignSy) = range(101, 134)
# Special codes: end of input, "no token yet", and scan error.
finalSy, noSy, errorSy = 999, 998, 997

# Scanner / parser / code-generator state.
ident = ""                  # text of the last identifier scanned
numberVal = 0               # value of the last number scanned
keywords = {}               # reserved word -> token code, filled by scannerInitialize
sy = noSy                   # current token
nextVariable = 0            # counter that numbers each defined identifier
symbols = [{} for _level in range(5)]   # one name -> number dict per lexical level
main = False
lexLevel = 0                # current procedure nesting depth (index into symbols)
outf = open ("output.c", "w")   # generated C code
indent = 0                  # trace indentation used by outx

# Driver: compile the PL/0 source named by the first command-line argument
# (default "square.pl0.txt") into output.c.  The files are closed even if
# compilation raises part-way through.
import sys
scannerInitialize (sys.argv[1] if len(sys.argv) > 1 else "square.pl0.txt")
listing = open ("list.txt", "w")   # opened for a listing; nothing in this file writes to it yet
try:
    nextCh()
    sy = nextSy()
    program()
finally:
    outf.close()
    fp.close ()
    listing.close()

